from huggingface_hub import HfApi
import os
import subprocess
import time

def download_file_aria2(url, destination_folder, file_name):
    """Download a single URL with aria2c into ``destination_folder/file_name``.

    The command is passed to ``subprocess.run`` as an argument list (no
    shell), so spaces, quotes, or other shell metacharacters in the folder,
    file name, or URL reach aria2c unchanged.

    Args:
        url: Address of the file to fetch.
        destination_folder: Directory handed to aria2c's ``--dir`` option.
        file_name: Output name handed to aria2c's ``--out`` option; aria2c
            interprets it relative to ``destination_folder``.

    Returns:
        True if aria2c exited with status 0, False if it exited with a
        non-zero status or could not be started at all.
    """
    download_command = [
        "aria2c",
        "--console-log-level=error",
        "-c",        # resume a partially downloaded file
        "-x", "16",  # max connections per server
        "-s", "16",  # number of splits per download
        "-k", "1M",  # minimum split size
        "-t", "10",  # timeout in seconds
        # The ``--opt=value`` form keeps a value that starts with "-" from
        # being parsed as another option.
        f"--dir={destination_folder}",
        f"--out={file_name}",
        url,
    ]
    try:
        completed = subprocess.run(download_command, stdout=subprocess.DEVNULL)
    except OSError as e:
        # Typically aria2c is not installed or not on PATH. Report and carry
        # on, as the shell-based invocation did, instead of raising.
        print(f"Could not run aria2c: {e}")
        return False
    return completed.returncode == 0

def download_hf_repo_files(repo_list, base_path, hf_endpoint, max_retries=None, retry_delay=5):
    """Download the files of each listed repository through ``hf_endpoint``.

    For every repo id the file list is requested via
    ``HfApi.list_repo_files``; each file except a root-level
    ``.gitattributes`` or ``README.md`` is then fetched from
    ``<hf_endpoint>/<repo>/resolve/main/<file>`` with ``download_file_aria2``.

    Args:
        repo_list: Iterable of repo ids such as ``"owner/name"``.
        base_path: Directory under which one sub-folder per repo is created.
            The repo named ``AnimateAnyone`` is the exception: its files go
            directly into ``base_path``.
        hf_endpoint: Base URL used both for the API client and for building
            the download URLs.
        max_retries: How many times a failed file listing is retried before
            the last error is re-raised. ``None`` (the default) retries
            without limit.
        retry_delay: Seconds to wait between listing attempts.

    Raises:
        Exception: The last listing error, once ``max_retries`` is exhausted.
    """
    from urllib.parse import quote

    hf_api = HfApi(endpoint=hf_endpoint)
    # Avoid a double slash in download URLs when the endpoint ends with "/".
    endpoint_root = hf_endpoint.rstrip("/")
    skipped = {".gitattributes", "README.md"}

    for repo in repo_list:
        repo_name = repo.split("/")[-1]
        # Special case: files of the "AnimateAnyone" repo are placed directly
        # in base_path instead of a per-repo sub-folder.
        folder = base_path if repo_name == "AnimateAnyone" else os.path.join(base_path, repo_name)
        os.makedirs(folder, exist_ok=True)

        print(f"Fetching files for {repo}...")
        failures = 0
        while True:
            try:
                listed = hf_api.list_repo_files(repo)
            except Exception as e:
                print(f"Error fetching files for {repo}: {e}")
                failures += 1
                if max_retries is not None and failures > max_retries:
                    raise
                time.sleep(retry_delay)
            else:
                files = [fi for fi in listed if fi not in skipped]
                break

        for fi in files:
            # Percent-encode the path ("/" is kept) so characters such as
            # "#", "?", "%" or spaces in a file name are not misread as URL
            # syntax.
            download_url = f"{endpoint_root}/{repo}/resolve/main/{quote(fi)}"
            print(f"Downloading {fi}...")
            outcome = download_file_aria2(download_url, folder, fi)
            # Only an explicit False counts as failure; any other return
            # value, including None, is treated as a completed download.
            if outcome is False:
                print(f"Failed to download {fi}")
            else:
                print(f"Downloaded {fi}")

def main():
    """Fetch the pretrained weights into ``pretrained_weights`` next to this script.

    The three repositories listed below are downloaded through the
    ``https://hf-mirror.com`` endpoint.
    """
    # Resolve the directory of this script (symlinks followed) so the target
    # folder does not depend on the current working directory.
    here = os.path.dirname(os.path.realpath(__file__))

    # Download the Hugging Face repository files.
    download_hf_repo_files(
        [
            "patrolli/AnimateAnyone",
            "bdsqlsz/stable-diffusion-v1-5",
            "bdsqlsz/image_encoder",
        ],
        os.path.join(here, "pretrained_weights"),
        "https://hf-mirror.com",
    )

# Start the download only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
